import argparse, os, pickle, math
import numpy as np

## Set these to the same as used in experiment
num_words = 60
N_test = 200

# Set confidence interval for metric
confidence = 95


correct_prior = 'ctm'

confidence_val = {90: 1.645, 95: 1.96, 99: 2.576}
z = confidence_val[confidence]

def top_k_ovelap(y_true, y_pred, k, z_score=None):
    """Return the mean top-``k`` index overlap between matching rows, with a CI radius.

    For each row ``i``, the ``k`` column indices holding the largest values of
    ``y_true[i]`` are intersected with those of ``y_pred[i]``. The row's score
    is the intersection size divided by ``k``, so it lies in [0, 1].

    Parameters
    ----------
    y_true, y_pred : array_like, shape (N, C)
        2-D score matrices with identical shapes.
    k : int
        Number of top entries per row to compare; must satisfy 1 <= k <= C.
    z_score : float, optional
        Normal critical value for the interval. Defaults to the module-level
        ``z`` derived from ``confidence``.

    Returns
    -------
    tuple
        ``(mean, radius)``: the mean overlap rate and
        ``z_score * std / sqrt(N)`` (population std), each rounded to 4 places.

    Raises
    ------
    ValueError
        If the inputs are not 2-D with matching shapes, have no rows, or
        ``k`` is outside ``[1, C]``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim != 2 or y_true.shape != y_pred.shape:
        raise ValueError(
            f'expected two 2-D arrays of equal shape, got {y_true.shape} and {y_pred.shape}'
        )
    n_rows, n_cols = y_true.shape
    if n_rows == 0:
        raise ValueError('need at least one row to compute an overlap rate')
    if not 1 <= k <= n_cols:
        raise ValueError(f'k must be between 1 and {n_cols}, got {k}')
    if z_score is None:
        z_score = z

    # argsort is ascending, so the last k columns are the k largest entries.
    top_true = np.argsort(y_true, axis=1)[:, -k:]
    top_pred = np.argsort(y_pred, axis=1)[:, -k:]
    overlap_rate = np.array(
        [len(set(t) & set(p)) / k for t, p in zip(top_true, top_pred)],
        dtype=float,
    )

    mean = overlap_rate.mean()
    std = overlap_rate.std()
    return round(mean, 4), round(z_score * std / math.sqrt(n_rows), 4)


# Correctly spelled alias; the original name is kept for existing callers.
top_k_overlap = top_k_ovelap
    

def _data_dir(alpha):
    """Return the directory holding the experiment outputs for ``alpha``."""
    return f'data/alpha{alpha}.0_{num_words}wordsdoc'


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; this is only
    acceptable because these files are trusted, locally produced outputs.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _check_same_shape(a, b, label):
    """Raise ValueError if arrays ``a`` and ``b`` differ in shape.

    An explicit check rather than ``assert`` so it still runs under ``python -O``.
    """
    if a.shape != b.shape:
        raise ValueError(f'{label}: shape {a.shape} does not match {b.shape}')


def _report_tv_distance(alpha, prior_list):
    """Print mean TV distance ± CI radius for each mis-specified prior at ``alpha``.

    The reference is ``MCMC_Posterior_{N_test}.pkl``. Every prior in
    ``prior_list`` other than ``correct_prior`` has its
    ``{prior}_prior_MCMC_Posterior_{N_test}.pkl`` compared row by row
    against that reference. Missing files are skipped rather than raising.
    """
    directory = _data_dir(alpha)
    reference_path = f'{directory}/MCMC_Posterior_{N_test}.pkl'
    if not os.path.isfile(reference_path):
        return
    true_posterior = np.asarray(_load_pickle(reference_path))

    for prior_type in prior_list:
        if prior_type == correct_prior:
            continue
        path = f'{directory}/{prior_type}_prior_MCMC_Posterior_{N_test}.pkl'
        # Skip missing runs instead of crashing, as the overlap reports do.
        if not os.path.isfile(path):
            continue
        L = np.asarray(_load_pickle(path))
        _check_same_shape(L, true_posterior, path)

        # Per-row TV distance: half the L1 distance between matching rows.
        TV_data = np.sum(np.abs(L - true_posterior) / 2, axis=1)
        mean_TV = TV_data.mean()
        r = z * TV_data.std() / math.sqrt(len(TV_data))
        print(f'{prior_type}: {round(mean_TV,4)} ± {round(r,4)}')


def _report_overlap(alpha, prior_list, method):
    """Print top-2 topic overlap ± CI radius for each prior's posterior at ``alpha``.

    ``method`` is the tag used in the posterior filename (``'MCMC'`` or
    ``'VI'``). The ground truth is ``test_topics_dist.pkl``. Missing files
    are skipped rather than raising.
    """
    directory = _data_dir(alpha)
    true_prior_path = f'{directory}/test_topics_dist.pkl'
    if not os.path.isfile(true_prior_path):
        return
    true_prior = np.array(_load_pickle(true_prior_path))

    for prior_type in prior_list:
        path = f'{directory}/{prior_type}_prior_{method}_Posterior_{N_test}.pkl'
        if not os.path.isfile(path):
            continue
        L = np.asarray(_load_pickle(path))
        _check_same_shape(L, true_prior, path)

        mean_overlap, r_overlap = top_k_ovelap(true_prior, L, 2)
        label = 'correct' if prior_type == correct_prior else 'mis-specified'
        print(f'{prior_type} ({label}): {round(mean_overlap, 4)} ± {round(r_overlap, 4)}')


if __name__ == '__main__':
    # Report the correctly specified prior first, then the mis-specified ones.
    prior_list = ['lda', 'ctm', 'pam']
    prior_list.remove(correct_prior)
    prior_list = [correct_prior] + prior_list
    alphas = range(1, 10, 2)

    print('Calculating TV distance between true topic posterior and MCMC-inferred posterior assuming some prior...')
    for alpha in alphas:
        print('alpha =', alpha)
        _report_tv_distance(alpha, prior_list)
        print()

    print('Calculating major topic overlap rate between true topic prior and MCMC-inferred posterior assuming some prior...')
    for alpha in alphas:
        print('alpha =', alpha)
        _report_overlap(alpha, prior_list, 'MCMC')
        print()

    print('Calculating major topic overlap rate between true topic prior and Variational Inference-inferred posterior assuming some prior...')
    for alpha in alphas:
        print('alpha =', alpha)
        _report_overlap(alpha, prior_list, 'VI')
        print()


            
            